# Copyright 2022 Twitter, Inc and Zhendong Wang.
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from agents.helpers import SinusoidalPosEmb


class MLP(nn.Module):
    """
    MLP conditioned on a time step, a state and an optional condition.

    ``x``, the embedded time step, ``state`` and the embedded condition are
    concatenated along dim 1. The result passes through three
    ``hidden_dim``-wide Linear+Mish layers and a final linear projection to
    ``action_dim`` features.

    Args:
        state_dim: Number of features in ``state``.
        action_dim: Number of features in ``x`` and in the output.
        device: Stored as ``self.device``. It is not used to allocate tensors
            in ``forward``; the zero condition follows ``x``'s device instead.
        t_dim: Width of the time-step embedding produced by ``time_mlp``.
        condition_dim: Width of the condition embedding.
        condition_pos_embed: If True, the condition goes through
            ``SinusoidalPosEmb(condition_dim)`` followed by a two-layer Mish
            MLP. Otherwise it goes through ``nn.Linear(1, condition_dim)``
            after a trailing axis is added.
        hidden_dim: Width of the hidden layers (default 256).
    """
    def __init__(self,
                 state_dim,
                 action_dim,
                 device,
                 t_dim=16,
                 condition_dim=16,
                 condition_pos_embed=True,
                 hidden_dim=256):

        super().__init__()
        self.device = device
        self.condition_dim = condition_dim
        self.condition_pos_embed = condition_pos_embed

        # Submodule names/order are kept stable so existing state_dicts load.
        self.time_mlp = nn.Sequential(
            SinusoidalPosEmb(t_dim),
            nn.Linear(t_dim, t_dim * 2),
            nn.Mish(),
            nn.Linear(t_dim * 2, t_dim),
        )
        if condition_pos_embed:
            self.condition_mlp = nn.Sequential(
                SinusoidalPosEmb(condition_dim),
                nn.Linear(condition_dim, condition_dim * 2),
                nn.Mish(),
                nn.Linear(condition_dim * 2, condition_dim),
            )
        else:
            # Raw condition value is projected from a single feature.
            self.condition_mlp = nn.Linear(1, condition_dim)

        input_dim = state_dim + action_dim + t_dim + condition_dim
        self.mid_layer = nn.Sequential(nn.Linear(input_dim, hidden_dim),
                                       nn.Mish(),
                                       nn.Linear(hidden_dim, hidden_dim),
                                       nn.Mish(),
                                       nn.Linear(hidden_dim, hidden_dim),
                                       nn.Mish())

        self.final_layer = nn.Linear(hidden_dim, action_dim)

    def _embed_condition(self, condition, x):
        """Return the embedding of ``condition``, or zeros when it is None.

        The zero embedding has shape ``(*x.shape[:-1], condition_dim)`` and is
        created on ``x``'s device and dtype. This way it concatenates cleanly
        with ``x`` even after the module has been moved with ``.to()``.
        """
        if condition is None:
            return torch.zeros([*x.shape[:-1], self.condition_dim],
                               device=x.device, dtype=x.dtype)
        if self.condition_pos_embed:
            return self.condition_mlp(condition)
        # nn.Linear(1, condition_dim) needs a trailing feature axis of size 1.
        return self.condition_mlp(condition.unsqueeze(-1))

    def forward(self, x, time, state, condition=None):
        """Run the network for a single condition.

        Args:
            x: Tensor with ``action_dim`` features. It is concatenated on
                dim 1, so it is assumed 2-D ``(batch, action_dim)``.
            time: Input to ``time_mlp``. Presumably a ``(batch,)`` tensor of
                time steps; confirm against ``SinusoidalPosEmb``.
            state: Tensor with ``state_dim`` features, same batch as ``x``.
            condition: Optional condition tensor. When None, a zero embedding
                is used.

        Returns:
            Output of ``final_layer``, with ``action_dim`` features.
        """
        t = self.time_mlp(time)
        cond_emb = self._embed_condition(condition, x)
        h = torch.cat([x, t, state, cond_emb], dim=1)
        return self.final_layer(self.mid_layer(h))

    def forward_multiple_conditions(self, x, time, state, conditions=(None,)):
        """Evaluate the network for several conditions in one batched pass.

        The inputs are tiled once per condition, run through the network
        together, and split back into one output per condition.

        Args:
            x, time, state: As in ``forward``. ``x`` and ``state`` are assumed
                2-D ``(batch, features)``.
            conditions: Iterable of condition tensors, or None entries for the
                zero condition. It must contain at least one entry.

        Returns:
            Tuple of ``len(conditions)`` tensors, one per condition, in input
            order. Each is what ``forward`` would return for that condition.

        Raises:
            ValueError: If ``conditions`` is empty.
        """
        conditions = tuple(conditions)
        n_cond = len(conditions)
        if n_cond == 0:
            raise ValueError("conditions must contain at least one entry")
        batch_size = x.shape[0]

        # The time embedding does not depend on the condition: compute once.
        t = self.time_mlp(time)
        cond_emb = torch.cat([self._embed_condition(c, x) for c in conditions],
                             dim=0)
        h = torch.cat([x.repeat([n_cond, 1]),
                       t.repeat([n_cond, 1]),
                       state.repeat([n_cond, 1]),
                       cond_emb], dim=1)
        out = self.final_layer(self.mid_layer(h))

        return torch.split(out, batch_size)


